What Happens To BERT Embeddings During Fine-tuning?

Codebase: https://github.com/r05323028/What_happens_to_bert_embeddings_during_fintuning

In this notebook, we'll try to reproduce What Happens To BERT Embeddings During Fine-tuning?, which was accepted at EMNLP 2020 in the Proceedings of the Third BlackboxNLP Workshop on Analyzing and Interpreting Neural Networks for NLP.


Brief Introduction to BERT

BERT is a stack of Transformer encoders inspired by Attention Is All You Need. The main idea of BERT is to pretrain a large encoder that captures rich linguistic information. Downstream NLP tasks can then fine-tune BERT, together with a small task-specific head, to solve specific problems.

qro152np50704pp7n8o928q5p9p8r908.jpg

For details on how to pretrain or fine-tune BERT, see the original paper.

Layerwise Analysis of BERT

The first time I saw a layerwise analysis of BERT was at CIKM 2019 in Beijing. I went to the conference with a colleague and saw this paper: How Does BERT Answer Questions? A Layer-Wise Analysis of Transformer Representations. It performs PCA to project the word representations of all layers into a 2D space after BERT answers questions and extracts answers. The projections show that the bottom layers of BERT contain low-level information such as topics, POS tags, and dependencies, while the top layers contain high-level information such as question-fact matching and answer extraction.
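The per-layer PCA projection described above can be sketched in a few lines of NumPy. This is a minimal illustration on random stand-in data, not the paper's actual pipeline: in their setting, `hidden_states` would be one layer's token vectors from BERT.

```python
import numpy as np

def pca_2d(hidden_states: np.ndarray) -> np.ndarray:
    """Project token representations of shape (n_tokens, hidden_dim) to 2D via PCA."""
    centered = hidden_states - hidden_states.mean(axis=0, keepdims=True)
    # SVD of the centered matrix; the top-2 right singular vectors span the
    # directions of maximal variance.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T

# Toy stand-in for one layer's hidden states (e.g. 768-dim BERT vectors).
rng = np.random.default_rng(0)
states = rng.normal(size=(50, 768))
coords = pca_2d(states)
print(coords.shape)  # (50, 2)
```

Plotting `coords` for each layer, colored by token property (topic, POS tag, etc.), gives the kind of layerwise visualization shown in the paper.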

截圖 2021-01-31 下午5.41.22.png

The authors built a cool interactive website. If you are interested, you can visit this page.

Load Data & Model

In this section, we load a wiki dataset and three models: bert-base-uncased, the pretrained model released by Google; bert-mnli, fine-tuned on the GLUE MNLI dataset; and bert-squad, fine-tuned on SQuAD.
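A sketch of loading the three models with the Hugging Face transformers library might look like the following. The two fine-tuned checkpoint names are community uploads chosen for illustration, not necessarily the checkpoints used in the original codebase:

```python
# Hypothetical checkpoint names for the three BERT variants compared here.
# The MNLI and SQuAD entries are community fine-tunes, used as assumptions.
CHECKPOINTS = {
    "pretrained": "bert-base-uncased",
    "mnli": "textattack/bert-base-uncased-MNLI",
    "squad": "csarron/bert-base-uncased-squad-v1",
}

def load_models():
    """Load a shared tokenizer and the three BERT variants with hidden states exposed."""
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINTS["pretrained"])
    models = {
        name: AutoModel.from_pretrained(ckpt, output_hidden_states=True)
        for name, ckpt in CHECKPOINTS.items()
    }
    return tokenizer, models
```

With `output_hidden_states=True`, each forward pass returns the hidden states of all layers, which is what the layerwise analyses below need.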

Representational Similarity Analysis (RSA)

We feed the same inputs to the pretrained model and the fine-tuned models, extract the hidden states of every layer, and compare the models layer by layer using cosine similarity.
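A minimal NumPy sketch of this comparison, on random stand-in data: for one layer, build each model's token-by-token cosine-similarity matrix and correlate the two matrices, which is the classic RSA recipe. The shapes and the perturbation below are illustrative assumptions.

```python
import numpy as np

def rsa_similarity(reps_a: np.ndarray, reps_b: np.ndarray) -> float:
    """Compare two models' representations of the same tokens.

    Each input is (n_tokens, hidden_dim). Build each model's token-by-token
    cosine-similarity matrix, then correlate the two matrices."""
    def sim_matrix(reps):
        normed = reps / np.linalg.norm(reps, axis=1, keepdims=True)
        return normed @ normed.T

    a = sim_matrix(reps_a).ravel()
    b = sim_matrix(reps_b).ravel()
    return float(np.corrcoef(a, b)[0, 1])

# Toy stand-in: a "fine-tuned" layer as a mild perturbation of the pretrained one.
rng = np.random.default_rng(0)
layer_pre = rng.normal(size=(30, 768))
layer_tuned = layer_pre + 0.1 * rng.normal(size=(30, 768))
score = rsa_similarity(layer_pre, layer_tuned)
```

Running this per layer for each (pretrained, fine-tuned) pair yields a similarity curve across depth, the kind of figure shown below.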

From the figure above, we can conclude that

Structural Probe

The structural probe is a method for evaluating whether a word-representation model learns syntactic structure. The main idea is the hypothesis that, if the model encodes syntax, parse-tree structure is preserved under a linear projection of its representations. The following figure shows this concept intuitively.

header.png

So, how do we evaluate it? The answer is: we train a probe model to predict the number of edges on the path between every pair of tokens in the parse tree. To make this amenable to gradient descent in Euclidean space, we define this number of edges as the tree distance $d(w_i, w_j)$.
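Computing $d(w_i, w_j)$ for a sentence is a small graph problem: treat the parse tree as an undirected graph and run BFS from every token. A minimal sketch, assuming the tree is given as a head-index array (a common dependency-parse encoding; the toy tree below is made up):

```python
from collections import deque

def tree_distances(heads):
    """All-pairs tree distance d(w_i, w_j) = number of edges between tokens.

    heads[i] is the parent index of token i in the parse tree (-1 for root)."""
    n = len(heads)
    adj = [[] for _ in range(n)]
    for child, head in enumerate(heads):
        if head >= 0:  # every non-root token has one edge to its head
            adj[child].append(head)
            adj[head].append(child)
    dist = [[0] * n for _ in range(n)]
    for src in range(n):  # BFS from each token
        seen, queue = {src}, deque([(src, 0)])
        while queue:
            node, d = queue.popleft()
            dist[src][node] = d
            for nxt in adj[node]:
                if nxt not in seen:
                    seen.add(nxt)
                    queue.append((nxt, d + 1))
    return dist

# Toy tree: token 1 is the root, tokens 0 and 2 attach to 1, token 3 to 2.
dist = tree_distances([1, -1, 1, 2])
```

These gold distances are what the probe's predicted distances are regressed against.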

distances.png

The following figure shows the idea: if the language model has really learned syntax, a probe model can project its word representations into a subspace that preserves the syntax-tree structure of the sentence.

space-dist.png

As stated above, we use a linear projection to map differences of hidden states into a subspace:

$$ {\bf B}({\bf h}_i - {\bf h}_j) $$

where ${\bf B}$ is the linear transformation matrix and ${\bf h}_i, {\bf h}_j$ are hidden states. Then, we can define $d_{{\bf B}}$ as the squared distance after the linear projection:

$$ d_{{\bf B}}({\bf h}_i, {\bf h}_j) = ({\bf B}({\bf h}_i - {\bf h}_j))^T ({\bf B}({\bf h}_i - {\bf h}_j)) $$

Finally, we can perform gradient descent to minimize the loss over all sentences $s_l$:

$$ \min_{{\bf B}} \sum_{l} \frac{1}{|s_l|^2} \sum_{i, j} \left| d(w_i, w_j) - d_{{\bf B}}({\bf h}_i, {\bf h}_j) \right| $$
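The distance $d_{{\bf B}}$ and the per-sentence loss above can be sketched in NumPy as follows. This only evaluates the objective on toy data; in practice ${\bf B}$ is trained with an optimizer such as Adam in a framework like PyTorch, and the shapes below are illustrative assumptions.

```python
import numpy as np

def probe_loss(B, hidden, tree_dist):
    """Structural-probe loss for one sentence of n tokens.

    B:         (rank, dim) projection matrix being learned
    hidden:    (n, dim) token hidden states h_i
    tree_dist: (n, n) gold parse-tree distances d(w_i, w_j)"""
    proj = hidden @ B.T                         # B h_i for every token
    diff = proj[:, None, :] - proj[None, :, :]  # B(h_i - h_j) for every pair
    d_B = (diff ** 2).sum(axis=-1)              # squared L2 norm = d_B(h_i, h_j)
    n = hidden.shape[0]
    return np.abs(tree_dist - d_B).sum() / n ** 2

# Toy data: a 5-token sentence with chain-shaped tree distances |i - j|.
rng = np.random.default_rng(0)
B = 0.01 * rng.normal(size=(64, 768))
hidden = rng.normal(size=(5, 768))
tree_dist = np.abs(np.arange(5)[:, None] - np.arange(5)[None, :])
loss = probe_loss(B, hidden, tree_dist)
```

Averaging this loss over a corpus and backpropagating through ${\bf B}$ recovers the training objective above.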

From the figure above, we can conclude that,

References